Constrained decoding with weighted automata

I want to share a take on grammar-constrained generation I’ve been working on for a while. This post is a preliminary writeup: enough detail to communicate the core “read the (G)LR stack with a weighted automaton” idea (and timestamp it). I’ll likely follow up with a post on the various tricks you need to actually build the damn things without the compiler exploding.

The problem

You have a grammar (often a JSON Schema, or “JSON-but-with-some-fields”, or a programming-language subset) and you want the LLM to produce grammatically valid output.

At each step, the LLM outputs logits over the vocabulary $V$. You sample a token. Append it to the output. Repeat.

Constrained decoding is: don’t sample tokens that would make the output grammatically invalid.

So we maintain a constraint state and, for each step, compute a mask $M \subseteq V$ of allowed tokens, and apply it to the logits.

No more broken JSON.

A precise way to say “allowed” (this matters later):

A token $v \in V$ is valid given the current prefix $w$ if there exists some continuation $u$ such that $wvu$ is in the grammar language.

Equivalently, the mask is

$$\text{GetMask}(w) = \{\, v \in V : \exists u.\; wvu \in L(G) \,\}.$$

That “$\exists u$” is the whole game: we’re not asking “does this token complete a valid document right now?”, we’re asking “could it still lead to a valid completion?”
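
To make the “$\exists u$” concrete, here’s a toy brute-force version of GetMask. The grammar (the regular language `a*b`), the DFA encoding, and the four-token vocabulary are all invented for illustration: a token is allowed iff, after consuming its bytes, some accepting state is still reachable.

```python
# Toy GetMask over the language a*b. A token is valid iff, after the DFA
# consumes it, an accepting state is still reachable (the "exists u" part).

DFA = {           # state -> {char: next_state}
    0: {"a": 0, "b": 1},
    1: {},
}
ACCEPT = {1}

def run(state, s):
    # advance the DFA over string s; None means "dead"
    for ch in s:
        state = DFA[state].get(ch)
        if state is None:
            return None
    return state

def can_reach_accept(state):
    # graph reachability: is any accepting state reachable from here?
    seen, todo = set(), [state]
    while todo:
        s = todo.pop()
        if s in seen:
            continue
        seen.add(s)
        if s in ACCEPT:
            return True
        todo.extend(DFA[s].values())
    return False

def get_mask(prefix, vocab):
    # assumes the prefix itself is still viable
    state = run(0, prefix)
    return {v for v in vocab
            if (s2 := run(state, v)) is not None and can_reach_accept(s2)}

# "a" is allowed even though "aaa" isn't a complete sentence: a valid
# continuation still exists. "ba" can never be completed, so it's masked out.
print(sorted(get_mask("aa", {"a", "b", "ab", "ba"})))  # ['a', 'ab', 'b']
```

The real problem is exactly this check, except the language is context-free, the vocabulary has ~200k entries, and the budget is about a millisecond.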

Two primitive operations

Think of constrained decoding as two functions:

- `commit(state, token) -> state`: advance the constraint state once a token has been sampled.
- `get_mask(state) -> mask`: compute the set of allowed next tokens.

At runtime the loop looks like:

loop:
  parallel:
    GPU: logits
    CPU: mask = get_mask(state)
  apply mask to logits
  sample token
  state = commit(state, token)

Nothing fancy, except: get_mask is on the critical path of decoding.

Why worst-case latency matters

LLM inference is batched: multiple user requests run on the GPU at once, and each step wants a mask at the same time.

If you want to generate ~1k tokens/sec, you have ~1ms to produce a mask. And if you miss even a single one—if a mask takes 2ms instead of 1—then in a batched setting it screws up scheduling: the GPU sits waiting on a straggler CPU mask, and suddenly your p99 latency is trash.

Masks can’t be late; they have to arrive on time every step. It’s worst-case performance that matters.

It’s not “can you do masking fast on average?”
It’s “can you do it fast every single step?”

Commit ~ incremental parsing

The commit op is basically incremental parsing: each call extends the parse of the growing prefix by one token.

Many battle-tested systems like tree-sitter and Lezer use Generalized LR (GLR) parsing with a graph-structured stack (GSS).

So commit is not the “new idea” here. commit is: use a decent incremental GLR parser and keep it compact.

The interesting part is get_mask.

Glossary

LLM Token
An element of the model vocabulary—a BPE token with an ID and byte string. ~200k of these in modern tokenizers.
Grammar Terminal
The symbols your parser actually consumes. Might be bytes, characters, or lexer tokens like --, IDENT, STRING.
Parser State ID
An integer identifying the LR automaton's current state. The parser stack is a sequence of these.

Generating masks

Now you have a parse state (often a GSS), and you want a mask over LLM tokens.

A straightforward approach is:

For each LLM token $t \in V$, check whether appending it keeps the parse valid.

But you don’t append an LLM token to the parser. You append terminal symbols.

And a single LLM token is a byte string which might:

- span several grammar terminals,
- stop partway through a terminal, or
- be segmentable into terminals in more than one way.

So “try token $t$” is really “try every way $t$’s bytes could advance the lexer into some terminal sequence, and then try advancing the parser for that terminal sequence.”

That mismatch is where a lot of “obvious” masking schemes get wrecked.

Example: segmentation explosions

Here’s a toy example that captures the shape of the problem.

Suppose your terminals include - and -- (or more generally, terminals that overlap). If you treat tokenization as “split into terminals in any way that works” (scannerless style), then a run of $N$ dashes has $F_{N+1}$ possible segmentations (Fibonacci growth), because you’re tiling length $N$ using pieces of length 1 and 2.

And tokenizers absolutely contain long weird tokens. In o200k_base there are tokens that are on the order of 100+ repeated punctuation characters. For $N \approx 112$, $F_{N+1}$ is $\sim 10^{23}$ segmentations.
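
A quick sanity check of that count. Tiling a run of $N$ dashes with pieces of length 1 and 2 gives the Fibonacci recurrence:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def segmentations(n: int) -> int:
    # number of ways to split n dashes into "-" and "--" pieces
    if n <= 1:
        return 1            # n=0: the empty split; n=1: a single "-"
    return segmentations(n - 1) + segmentations(n - 2)

print(segmentations(10))    # 89, i.e. F_11
print(segmentations(112))   # ~1.8e23
```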

That’s for a single candidate token.

And there are ~200k tokens to check.

This is obviously not something you can do per decoding step.

Why tries/trellises still hurt in p99.9 land

A common direction is:

- build a trie over the vocabulary’s byte strings,
- walk the trie while simulating the lexer and parser on each branch,
- prune a branch as soon as it becomes invalid, and collect the surviving tokens into the mask.

This can work well for many grammars and workloads. People implement this and get decent average performance, especially for “simple” grammars.

But for general CFG-ish constraints (especially when you care about worst-case), you run into two issues:

1) You end up doing real parser work “inside” masking

Even if your lexical side is deterministic (say you enforce longest-match, or you have a DFA lexer), a single grammar terminal shift can trigger a chain of reductions.

In an LR parser, “shift terminal $t$” is often really:

  1. while top-of-stack wants to reduce, do some reductions (pop some states, push a goto state),
  2. then shift $t$.

Those reduction chains are table-driven and fast in the happy path, but in GLR they can fan out. The operations are pointer-heavy and cache-unfriendly: you’re manipulating a GSS, merging nodes, pruning branches, etc.

If you do that inside masking—i.e. while traversing a trie of possible tokens—you’re effectively building and mutating a GSS for a hypothetical next token, and doing it potentially thousands of times per step.

That’s exactly the kind of p99.9 work you don’t want on the critical path.

2) You still pay for long/ugly tokens

Even with tries, there are tokens whose byte strings correspond to long sequences of terminals (think: long runs of punctuation, or tokens containing lots of structural characters plus whitespace/newlines). Those paths can survive pruning for a while, so you end up doing a lot of terminal-level processing per mask computation.

And yes, long ugly sequences can be syntactically valid in real languages. Python, JS, etc. happily accept monstrosities like:

------------------1

That means “a lot of unary minus operators.” It’s valid, and it forces the parser to do real work on a long terminal sequence.
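
You can check this from any Python REPL (the counts below are arbitrary): each `-` is a unary negation, so an even number of them gives the operand back and an odd number negates it, and the parser really does build a nested chain of unary-minus nodes.

```python
# Long runs of unary minus are syntactically valid Python.
print(eval("-" * 18 + "1"))   # 1
print(eval("-" * 19 + "1"))   # -1
```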

You can transform the grammar to remove unit/null reductions. You can aggressively merge equivalent GSS nodes. You can add clever memoization.

I tried hard to make the “trie + incremental parse simulation” approach behave well in worst-case latency terms. In my experience, it’s a dead end if you’re aiming for predictable sub-millisecond masking on arbitrary inputs/grammars.

Invert the problem

Token validity depends on what’s on the stack.

Instead of asking:

“does LLM token $t$ work on this stack?” (for each $t$)

ask:

“given this stack, which LLM tokens work?”

That’s the reframing.

Now we want a data structure that turns:

stack → allowed-token bitset

into something we can execute quickly.

That’s where the weighted automaton comes in.

Weighted automata over parser states

Think of a finite automaton, except:

- every transition carries a weight, and
- a run combines the weights it encounters along the way.

In the automata I’m talking about:

- the input symbols are parser state IDs (the automaton reads the stack),
- the weights are bitsets over the LLM vocabulary,
- weights combine by intersection along a path and by union where paths merge.

So it’s like “run a DFA on the stack”, but with “which tokens survive” flowing through it.

The automaton reads a representation of your current parse configuration:

A token is valid if it’s valid on any parse path in the GSS, so on a GSS you just union the results across paths.

Another way to see it (why weights help)

Another mental model that helped me:

  1. Fix an LLM token $t$.
  2. Consider all the ways its bytes could be tokenized into grammar terminals (given incremental lexing, overlapping terminals, etc.).
  3. For each such terminal sequence, consider all stack configurations from which that sequence can be legally parsed (shift/reduce/goto) without error.

That set of stacks is a regular language over parser state IDs (this is closely related to the classical “viable prefixes are regular” result from LR theory).

So you can imagine an automaton $A_t$ that recognizes “stacks on which token $t$ is valid.”

Of course, if you do that for every $t \in V$, you’d have 200k automata. That’s useless at runtime.

A weighted automaton is how you smash those 200k membership tests into one run:

- label each transition with the bitset of tokens whose automaton $A_t$ contains that transition,
- intersect the bitsets along a path, and union them where paths merge,
- a token is in the mask iff its bit survives to an accepting state.

Same computation, but vectorized across tokens via bitsets.
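
A tiny illustration of that vectorization (the weights and token count are made up): suppose three tokens share a transition path, and each edge records, as a 3-bit bitset, which tokens’ individual automata $A_t$ contain that edge. One run with bitwise ANDs then answers all three membership tests at once.

```python
def run_weighted(edge_weights, n_tokens=3):
    # start with every token alive; each edge's bitset filters the survivors
    alive = (1 << n_tokens) - 1
    for w in edge_weights:
        alive &= w          # one AND instead of n_tokens separate boolean runs
    return alive

# token 0 dies on the first edge, token 2 on the second; token 1 survives both
print(bin(run_weighted([0b110, 0b011])))  # 0b10
```

With 200k tokens the bitsets are a few kilobytes wide, but the shape of the computation is identical: a handful of wide ANDs and ORs per stack symbol.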

A nice optimization: stop reading the stack early

If you implement the obvious “read stack symbols until you hit bottom” approach, you’ll still sometimes do more work than needed: often, after reading a small suffix of the stack, the set of valid tokens is already determined and deeper stack symbols can’t change it.

You can bake this into the automaton:

- precompute, for each automaton state, a “final weight”: the tokens whose validity is guaranteed once the run reaches that state,
- at runtime, move those tokens out of the frontier into a “decided” set as you go.

Operationally, it’s the difference between:

- “read the whole stack, then see what’s accepted”, and
- “retire tokens the moment they’re decided, and track only the tokens whose fate still depends on deeper stack symbols.”

So you can stop when there’s nothing left in flight.

I’m not going to cite code here, but conceptually it’s just a precomputed fixed-point thing: for each automaton state, compute which tokens are guaranteed from this point (“final weight”), subtract them from the weights you propagate forward, and you get a monotone decreasing set that hits empty quickly.
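
A minimal sketch of that bookkeeping (the `FINAL` table, state names, and bitset widths are hypothetical): `FINAL[q]` is the precomputed bitset of tokens whose validity is guaranteed once the run reaches automaton state `q`, no matter what the deeper stack symbols say. At each step you move those tokens from the frontier into the decided set, so the frontier shrinks monotonically.

```python
# Hypothetical precomputed final weights: token 1 is guaranteed valid
# once the run reaches state q1.
FINAL = {"q0": 0b000, "q1": 0b010}

def absorb_final(frontier, decided):
    # frontier: automaton state -> bitset of still-undecided tokens
    new_frontier = {}
    for q, toks in frontier.items():
        guaranteed = toks & FINAL.get(q, 0)
        decided |= guaranteed
        rest = toks & ~guaranteed
        if rest:                      # keep only tokens whose fate is open
            new_frontier[q] = rest
    return new_frontier, decided
```

Once `new_frontier` is empty, nothing about the deeper stack can change the mask, and the run stops.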

Runtime sketch (single stack)

For a single LR stack, the runtime looks like:

def get_mask(stack_state_ids_top_to_bottom):
    # frontier: map automaton_state -> bitset(tokens)
    frontier = { A.start: ALL_TOKENS }

    decided = EMPTY  # tokens already known valid from suffix read so far

    for sid in stack_state_ids_top_to_bottom:
        new = {}
        for a_state, tokens in frontier.items():
            for (a2, weight) in A.step(a_state, sid):
                tokens2 = tokens & weight                 # intersection (filter)
                if tokens2.any():
                    new[a2] = new.get(a2, EMPTY) | tokens2  # union (merge paths)
        frontier = new

        # Optionally accumulate “final/decided” contribution here and
        # remove it from frontier tokens, so frontier only contains
        # “still-undecided” tokens.

        if not frontier:               # nothing left that depends on deeper stack
            break

    return decided | combine_accepting(frontier)
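
To make the sketch concrete, here’s a tiny self-contained instantiation. The automaton, stack contents, and three-token vocabulary are all invented for illustration, and I fold the “decided” bookkeeping away, just combining accepting states at the end:

```python
ALL_TOKENS, EMPTY = 0b111, 0   # 3 LLM tokens, bitsets encoded as ints

# (automaton_state, stack_state_id) -> [(next_state, weight_bitset), ...]
TRANS = {
    ("q0", 5): [("q1", 0b011)],                 # state 5 on top keeps t0, t1
    ("q1", 2): [("q2", 0b001), ("q3", 0b010)],  # paths diverge, weights differ
}
ACCEPTING = {"q2", "q3"}

def get_mask(stack_top_to_bottom):
    frontier = {"q0": ALL_TOKENS}
    for sid in stack_top_to_bottom:
        new = {}
        for a_state, tokens in frontier.items():
            for a2, weight in TRANS.get((a_state, sid), []):
                t2 = tokens & weight                   # intersect along a path
                if t2:
                    new[a2] = new.get(a2, EMPTY) | t2  # union at merges
        frontier = new
        if not frontier:                               # early exit
            break
    mask = EMPTY
    for a_state, tokens in frontier.items():
        if a_state in ACCEPTING:
            mask |= tokens
    return mask

print(bin(get_mask([5, 2])))   # 0b11: tokens t0 and t1 are valid here
```

Note that the work done is independent of how many tokens the bitsets describe: widening `ALL_TOKENS` to 200k bits changes the cost of each AND/OR, not the number of them.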

This is the key operational point:

- each stack symbol is read once,
- the work per symbol is a handful of bitset intersections and unions,
- nowhere is there a loop over the ~200k-token vocabulary.

Runtime on a GSS

For a GSS, you do basically the same computation, but over a graph rather than a single list.

Conceptually:

- seed the automaton’s start state at each GSS top,
- propagate (automaton state, token bitset) frontiers along GSS edges toward the bottom,
- where GSS paths share structure, the frontiers merge.

A token is valid if it’s valid on any stack path, so whenever GSS paths merge you union their token sets, and whenever automaton paths merge you union there too.

The important part is: it’s still the same “bitset flows through a graph via ∩ and ∪” pattern. No backtracking, no “try token $t$” loop.

So where does the automaton come from?

Up to now I’ve treated “the weighted automaton” as magic.

It’s not magic, it’s just precomputation.

At a high level, you want an automaton that answers:

given the current lexer+parser state (represented by the stack/GSS), which LLM tokens could lead to a valid continuation?

There are two distinct problems mixed together:

  1. Lexical: what grammar terminals could a given LLM token’s bytes produce (given where we are in the lexer)?
  2. Syntactic: if we fed those terminals to the LR parser, would they be legal given the current stack?

The thing I compile is basically a composition of two automata:

- a Terminal DWA, handling the lexical side (LLM token bytes → grammar terminals), and
- per-terminal template automata, handling the syntactic side (terminal → stack effect).

I’ll keep this section high-level (the details are where the compile-time pain is), but the decomposition matters because it explains why the runtime is so simple.

1) Token → terminals (Terminal DWA)

An LLM token is a byte string.

A grammar tokenizer/lexer consumes a stream of bytes and emits grammar terminals. Crucially:

- the lexer is stateful: what the next bytes mean depends on where you are mid-terminal,
- an LLM token’s bytes can end in the middle of a terminal, and
- overlapping terminals mean there can be several valid segmentations.

So the mapping “token $v$ → terminal sequence” is not a single fixed lookup. It depends on the current lexer state.

The Terminal DWA is the precomputed structure that answers:

from lexer state $x$, which LLM tokens $v$ can produce which terminal $t$ next (and what lexer state do we end up in)?

A practical way to build it is, roughly:

- start from the lexer’s byte-level automaton,
- run every vocabulary token’s byte string through it from every reachable lexer state, recording which terminals get emitted and where the lexer ends up,
- collect the results into transitions weighted by token bitsets.

Then determinize.

The key thing the Terminal DWA buys you is: it collapses “iterate over 200k tokens and run the lexer” into “traverse a small automaton state space and get bitsets.”

2) Terminal → stack effect (Template automata)

Now suppose the lexer says “the next terminal is $t$.” What does the LR parser do?

You can precompute, for each terminal $t$, an automaton that reads stack state IDs (top down) and represents all the ways the parser could legally process $t$, including the reduction chain.

Internally I represent the net stack transformation using a push/pop algebra (the polycyclic monoid is the clean formalism), but you don’t need to care about that to get the intuition:

- each terminal’s template reads stack state IDs from the top down,
- it “pops” the states its reduction chain would pop and “pushes” the goto states the chain would push,
- concatenating templates for consecutive terminals gives the net stack effect of the whole sequence.

This is the piece that takes “pointer-heavy GLR reduction simulation” off the runtime path and turns it into precomputed transitions.

3) Compose them into a Parser DWA

Finally, you compose:

- the Terminal DWA (LLM token bytes → grammar terminals), with
- the per-terminal template automata (terminal → stack effect),

to get one deterministic weighted automaton that:

- reads parser state IDs off the stack, top down,
- carries bitsets of LLM tokens as weights, and
- outputs, for the current configuration, the set of valid tokens.

This is the automaton you run in get_mask.

One subtle but very useful algebraic fact during composition: when you concatenate stack-effect templates, you get adjacent “push then pop” pairs that cancel. That lets you simplify aggressively while building the composed automaton (and prune impossible paths when pushes/pops mismatch).
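
A toy version of that cancellation (my own encoding, not the real compiler’s): represent a stack effect as a sequence of `('pop', x)` / `('push', x)` operations. Composing two effects cancels adjacent push/pop pairs on the same symbol and prunes mismatches.

```python
def compose(e1, e2):
    """Concatenate stack effects, canceling push(x)·pop(x) pairs."""
    ops = list(e1)
    for op in e2:
        if ops and ops[-1][0] == "push" and op[0] == "pop":
            if ops[-1][1] == op[1]:
                ops.pop()        # push(x) then pop(x): net no-op
                continue
            return None          # push(x) then pop(y), x != y: impossible path
        ops.append(op)
    return ops

print(compose([("push", 3)], [("pop", 3), ("push", 7)]))  # [('push', 7)]
print(compose([("push", 3)], [("pop", 4)]))               # None
```

The `None` case is the zero of the monoid: a composed path that can never match any stack, which is exactly what lets you prune during compilation.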

Also: at runtime I don’t actually need the automaton to update the stack. commit is implemented by the real lexer+parser. get_mask just needs to know whether a token is valid, so the automaton can discard some “push” bookkeeping once it has done its filtering job.

Why you only read a bounded suffix of the stack

A question I always get is: “do you really scan the entire stack every time?”

No. In practice you only need a bounded suffix, and you can make that precise.

Two related reasons:

  1. The “decided set” trick above: tokens are retired as soon as their validity is settled, and in practice the undecided frontier empties after a small number of stack symbols.
  2. LR structure: the set of stacks on which a token is valid is a regular language, so a finite amount of automaton state summarizes everything that deeper symbols could still contribute.

In implementation terms: masks stabilize quickly, and you can early-exit.

Tokenization ambiguity (and longest-match) in streaming generation

The term “token” is overloaded:

- LLM tokens: BPE byte strings from the model’s vocabulary,
- grammar terminals: the lexer’s tokens (--, IDENT, STRING, …).

These tokenizations don’t align.

Even if your lexer uses longest-match, longest-match is inherently forward-looking: you can’t know a match is final until you see what comes next.

Classic example: terminals + and ++ (or - and --).

In a streaming setting, after you see the first "+", you’re in a state where:

- the lexer may already hold a complete + terminal, or
- it may be partway through a ++ terminal,

and you can’t know which until the next byte arrives.

So you have to represent that uncertainty somehow.

The approach I use is what I call inhibited terminals (this is just a convenient operational trick, not a deep theory term):

- when a match could still be extended, fork: one branch emits the shorter terminal now, the other keeps lexing toward the longer one,
- the branch that emitted early records the longer terminal as inhibited: if the following bytes would actually complete it, that branch is pruned, since longest-match would have taken the longer terminal.

This plays nicely with GLR+GSS, because “fork and prune” is already the parser’s native move.

Practically, the constraint state you carry around is not just “parser stack(s)”; it’s “(parser stack(s), lexer state(s))”. get_mask needs to condition on both, which is why the Parser DWA has multiple initial states (one per active lexer state).
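
Here’s a toy fork-and-prune simulation of that uncertainty (the terminal set `+`, `++`, `1` and all names are invented): each hypothesis is a pair (terminals emitted so far, pending partial match), and emitting the shorter terminal is allowed only when the next byte can’t extend it, which is the inhibition check.

```python
TERMS = {"+", "++", "1"}

def step(hyps, ch):
    new = set()
    for emitted, pend in hyps:
        p = pend + ch
        if any(t.startswith(p) for t in TERMS):
            new.add((emitted, p))        # keep matching toward a longer terminal
        # close `pend` and start a new terminal with ch, unless some terminal
        # extends pend by ch (then longest-match forbids emitting pend now)
        if pend in TERMS and not any(t.startswith(p) and len(t) > len(pend)
                                     for t in TERMS):
            if any(t.startswith(ch) for t in TERMS):
                new.add((emitted + (pend,), ch))
    return new

def lex_stream(s):
    hyps = {((), "")}
    for ch in s:
        hyps = step(hyps, ch)
    # end of input: a pending match must itself be a complete terminal
    return {emitted + (pend,) for emitted, pend in hyps if pend in TERMS}

print(lex_stream("+1"))    # {('+', '1')}
print(lex_stream("+++"))   # {('++', '+')}
```

After a single "+", one hypothesis survives with pending "+", representing both "complete + terminal" and "prefix of ++"; the next byte collapses it one way or the other, which is the pruning the GSS does natively.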

Putting it together

At runtime you now have two different kinds of work, with very different performance profiles:

commit(state, token) -> state

This is the inherently “CFG-y” part. It can still have pathological cases (that’s life), but you at least get decades of parsing engineering behind it.

get_mask(state) -> mask

This is the performance-critical part, and it’s exactly where precomputation buys you predictability.

Why this feels close to optimal

I’m deliberately avoiding “this is optimal” as a theorem. Parsing isn’t a domain where you get many satisfying worst-case optimality results that also match real-world constraints.

But it feels close to optimal for two reasons:

On the commit side: GLR/GSS is a strong local optimum

For incremental CFG parsing with ambiguity:

- a graph-structured stack with shared structure,
- table-driven shifts, reductions, and gotos,
- forking only where the input is genuinely ambiguous

…is a pretty hard baseline to beat. There’s a reason people working on these problems keep converging on GLR-like techniques: tree-sitter and Lezer, mentioned above, are both in this family.

You still inherit GLR’s lack of comforting worst-case bounds in “fully adversarial ambiguity” settings. In theory, GLR can degrade badly on highly ambiguous grammars/inputs. In practice, for “programming-language-ish” grammars with reasonable disambiguation, it behaves well.

On the mask side: “read once, never backtrack” is what you want

The weighted-automaton mask computation does the minimum work you’d reasonably hope for:

- it reads each stack symbol (or GSS node) once,
- does a constant number of bitset operations per transition,
- never backtracks, and never iterates over the vocabulary.

In other words: the runtime work scales with the size of the parse configuration, not with vocabulary size and not with “how gnarly are the tokens.”

Fast run, slow compile

The cost shifts to compile time.

Precompiling the grammar into this Parser DWA involves determinization and simplification in a semiring-ish setting (bitset weights, union at merges, intersection along paths). If you do it naively, large grammars can blow up in memory/time.

Getting compile-time and memory to behave took most of my engineering effort.

That’s probably the next post.


If you’re building constrained decoding and you care about p99.9 latency, my main takeaway is: don’t ask, for each of ~200k tokens, “is this token valid on the current stack?”. Instead, precompute a structure that reads the parse state once and answers, for all tokens at once, “which tokens are valid here?”

That inversion is the whole trick.